import tensorflow as tf
from tensorflow.python.framework import ops
import sys
import os
# Absolute directory containing this module; it is also appended to sys.path.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
# Load the compiled custom-op library 'out.so' that sits next to this file.
# It provides the ops called below (query_ball_point, my_query_ball_point).
grouping_module=tf.load_op_library(os.path.join(BASE_DIR, 'out.so'))
def query_ball_point(radius, nsample, xyz1, xyz2):
    '''
    Thin wrapper around the compiled ``query_ball_point`` op.

    Note that this wrapper takes ``(radius, nsample, xyz1, xyz2)``, while the
    underlying op is invoked with the point clouds first:
    ``(xyz1, xyz2, radius, nsample)``.

    Input:
        radius: float32, ball search radius
        nsample: int32, number of points selected in each ball region
        xyz1: (batch_size, ndataset, 3) float32 array, input points
        xyz2: (batch_size, npoint, 3) float32 array, query points
    Output:
        idx: (batch_size, npoint, nsample) int32 array, indices to input points
        pts_cnt: (batch_size, npoint) int32 array, number of unique points in each local region
    '''
    ball_query_op = grouping_module.query_ball_point
    # Reorder: point clouds first, scalar search parameters last.
    return ball_query_op(xyz1, xyz2, radius, nsample)




def _query_ball(query_radius, num_max_neighbor, select_mask, all_pos, all_feature):
    '''
    Build the neighbor features of the selected particles via the compiled
    ``my_query_ball_point`` op.

    For each selected particle, the op gathers neighbor particles (among all
    particles) within ``query_radius`` and returns their relative position
    concatenated with their feature.

    :param query_radius: Float. The radius of the query ball.
    :param num_max_neighbor: Int. The max number of neighbor particles kept
        per selected particle ("max" below).
    :param select_mask: Bool array, [N]. Mask of the selected particles; the
        selected particles are ``all_pos[select_mask, :]``, i.e. [F, 3].
    :param all_pos: Float array, [N, 3]. The coordinates of all particles.
    :param all_feature: Float array, [N, 4]. The features of all particles.
    :return: Tensor, [F, max, 3+4]. For each selected particle, the
        relative_pos + feature of its neighbors, where relative_pos is
        {pos of the neighbor - pos of the selected particle}.
        NOTE(review): the dtype is whatever the op emits; an earlier docstring
        said "Int", but positions/features are float -- TODO confirm.

    Per the op's contract (implemented in out.so, not visible here): a
    particle is not its own neighbor. If there are more than ``max``
    neighbors, the nearest ones are kept. If there are fewer, the remaining
    relative_pos slots are filled with [0, 5*r, 0] and the feature slots with
    [0, 0, 0, 0].

    ``select_mask``, ``all_pos`` and ``all_feature`` are indexed NumPy-style
    (boolean mask on the first axis), so they must support that (e.g. NumPy
    arrays).
    '''
    select_pos = all_pos[select_mask, :]
    select_feature = all_feature[select_mask, :]

    # Single call expression: a trailing comma after the closing paren would
    # silently wrap the result in a one-element tuple.
    neighbor_feature = grouping_module.my_query_ball_point(
        select_pos, select_feature, all_pos, all_feature,
        query_radius, num_max_neighbor)
    return neighbor_feature


if __name__ == '__main__':
    import numpy as np

    # Smoke test: run the neighbor query on random data and print the result.
    np.random.seed(100)
    num_points = 512
    pos = np.random.random((num_points, 3)).astype('float32')
    feat = np.random.random((num_points, 4)).astype('float32')
    # Casting floats in [0, 1) with .astype('bool') marks every nonzero value
    # True (i.e. practically every point). Compare against a threshold
    # instead, so that a random ~50% subset is selected.
    select_mask = np.random.random(num_points) < 0.5

    with tf.device('/gpu:0'):
        radius = 0.1
        nsample = 64
        out = _query_ball(radius, nsample, select_mask, pos, feat)

    # NOTE(review): in TF1 graph mode this prints the symbolic tensor rather
    # than its value (evaluate it in a session there); in eager mode it prints
    # the values. TODO confirm the targeted TensorFlow version.
    print(out)